ML pipelines: RunInference - OSS Image Object detection, OSS Image Captioning, OSS Image Classification#37186
ML pipelines: RunInference - OSS Image Object detection, OSS Image Captioning, OSS Image Classification#37186Amar3tto wants to merge 36 commits into
Conversation
Summary of ChangesHello @Amar3tto, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances Apache Beam's machine learning capabilities by integrating a new PyTorch-based image object detection pipeline. The pipeline leverages the RunInference transform for efficient batched GPU inference with open-source TorchVision models, processing images from cloud storage and outputting structured detection results to BigQuery. This addition is complemented by a new performance benchmark and corresponding documentation, ensuring that the pipeline's efficiency and resource usage can be consistently monitored and evaluated. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Ignored Files
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #37186 +/- ##
=============================================
- Coverage 55.28% 36.33% -18.96%
Complexity 1676 1676
=============================================
Files 1067 1069 +2
Lines 167148 167178 +30
Branches 1208 1208
=============================================
- Hits 92415 60737 -31678
- Misses 72551 104259 +31708
Partials 2182 2182
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Assigning reviewers: R: @claudevdm for label python. Note: If you would like to opt out of this review, comment Available commands:
The PR bot will only process comments in the main thread (not review comments). |
|
@Abacn Could you please help with review? |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces three new ML inference pipelines for image classification, object detection, and image captioning using PyTorch, along with their corresponding benchmarks and documentation. The pipelines are well-structured and showcase advanced Beam features like RunInference with custom model handlers and stateful DoFns. My review focuses on improving scalability, robustness, and maintainability. I've identified a few key areas for improvement, including a scalability bottleneck in the data loading pipelines, several instances of broad exception handling that could mask errors, some potentially buggy logic, and a few copy-paste errors in the new documentation pages. Overall, this is a valuable contribution, and the suggested changes aim to make these examples more robust and easier to understand.
|
Reminder, please take a look at this pr: @claudevdm @liferoad @shunping |
|
Assigning new set of reviewers because Pr has gone too long without review. If you would like to opt out of this review, comment R: @jrmccluskey for label python. Available commands:
|
|
Reminder, please take a look at this pr: @jrmccluskey @damccorm |
|
Assigning new set of reviewers because Pr has gone too long without review. If you would like to opt out of this review, comment R: @shunping for label python. Available commands:
|
|
waiting on author |
a7ba9db to
439b5aa
Compare
|
Could you please fix the formatting failures? Also, please avoid rebasing when possible to avoid breaking GitHub's review features |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces several new PyTorch-based inference pipelines (image captioning, object detection, and image classification with adaptive batch sizing) along with their corresponding benchmarks, requirements, and performance documentation. The review feedback highlights critical bugs regarding shape mismatches during batching in the classification and object detection pipelines, a performance bottleneck from redundant image encoding in the CLIP model, an architectural flaw where model warmup runs on the submission client instead of the workers, and a robustness issue concerning guaranteed cleanup of Pub/Sub resources.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| del inference_args | ||
| del model_id |
There was a problem hiding this comment.
It is used instead of # pylint: disable=unused-argument
There was a problem hiding this comment.
I think we should probably use the annotation instead. This is more confusing IMO, whereas that has clear intent
|
/gemini review |
| del inference_args | ||
| del model_id |
There was a problem hiding this comment.
I think we should probably use the annotation instead. This is more confusing IMO, whereas that has clear intent
| self._batch_size = bs | ||
| self._inference_batch_size = bs |
There was a problem hiding this comment.
Why is this nondeterministic? Can't we set this up front based on what parameters a user passes in?
There was a problem hiding this comment.
Code Review
This pull request adds several new PyTorch-based inference pipelines and benchmarks to Apache Beam's Python SDK, covering image captioning (BLIP + CLIP), image object detection (Faster R-CNN ResNet-50), and image classification (EfficientNet-B0 with right-fitting), along with corresponding documentation and performance tracking configurations. The review feedback highlights several critical issues and improvement opportunities: a potential GPU OOM in the right-fitting handler due to runner-side batching which requires sub-batching in run_inference; hardcoded GPU resource hints that should be conditionally applied; potential mismatches when parsing Pub/Sub paths; race conditions from modifying shared pipeline_args in background threads; and opportunities to simplify device transfers using Hugging Face's built-in .to(device) method.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| class RightFittingPytorchModelHandlerTensor(PytorchModelHandlerTensor): | ||
| def __init__(self, batch_sizes_to_try, image_size, *args, **kwargs): | ||
| self._batch_sizes_to_try = batch_sizes_to_try | ||
| self._rightfit_image_size = image_size | ||
| super().__init__(*args, **kwargs) | ||
|
|
||
| def load_model(self): | ||
| model = super().load_model() | ||
| last_err = None | ||
|
|
||
| for bs in self._batch_sizes_to_try: | ||
| try: | ||
| model_device = next(model.parameters()).device | ||
| dummy = torch.zeros( | ||
| (bs, 3, self._rightfit_image_size, self._rightfit_image_size), | ||
| dtype=torch.float32, | ||
| device=model_device) | ||
|
|
||
| with torch.no_grad(): | ||
| model(dummy) | ||
|
|
||
| self._batch_size = bs | ||
| self._inference_batch_size = bs | ||
| logging.info("Selected inference batch size: %s", bs) | ||
| return model | ||
| except RuntimeError as e: | ||
| last_err = e | ||
| logging.warning("Batch size %s failed during worker warmup: %s", bs, e) | ||
|
|
||
| if torch.cuda.is_available(): | ||
| torch.cuda.empty_cache() | ||
|
|
||
| raise RuntimeError( | ||
| f"No valid inference batch size found from {self._batch_sizes_to_try}" | ||
| ) from last_err |
There was a problem hiding this comment.
The BatchElements transform is applied on the client side during pipeline construction using the initial maximum batch size (e.g., 64). Therefore, even if load_model dynamically selects a smaller batch size (e.g., 16) on the worker, the runner will still send batches of up to 64 elements to run_inference, which will cause GPU OOM. To fix this, RightFittingPytorchModelHandlerTensor must override run_inference to sub-batch the incoming batch to the dynamically selected self._inference_batch_size.
class RightFittingPytorchModelHandlerTensor(PytorchModelHandlerTensor):
def __init__(self, batch_sizes_to_try, image_size, *args, **kwargs):
self._batch_sizes_to_try = batch_sizes_to_try
self._rightfit_image_size = image_size
super().__init__(*args, **kwargs)
def load_model(self):
model = super().load_model()
last_err = None
for bs in self._batch_sizes_to_try:
try:
model_device = next(model.parameters()).device
dummy = torch.zeros(
(bs, 3, self._rightfit_image_size, self._rightfit_image_size),
dtype=torch.float32,
device=model_device)
with torch.no_grad():
model(dummy)
self._batch_size = bs
self._inference_batch_size = bs
logging.info("Selected inference batch size: %s", bs)
return model
except RuntimeError as e:
last_err = e
logging.warning("Batch size %s failed during worker warmup: %s", bs, e)
if torch.cuda.is_available():
torch.cuda.empty_cache()
raise RuntimeError(
f"No valid inference batch size found from {self._batch_sizes_to_try}"
) from last_err
def run_inference(self, batch, model, inference_args=None):
size = self._inference_batch_size
sub_batches = [batch[i:i + size] for i in range(0, len(batch), size)]
results = []
for sub_batch in sub_batches:
results.extend(super().run_inference(sub_batch, model, inference_args))
return results| predictions = ( | ||
| to_infer | ||
| | 'Reshuffle' >> beam.Reshuffle() | ||
| | 'RunInference' >> RunInference( | ||
| KeyedModelHandler(model_handler)).with_resource_hints( | ||
| accelerator="type:nvidia-tesla-t4;count:1;install-nvidia-driver")) |
There was a problem hiding this comment.
If the user runs the pipeline with --device CPU, the pipeline will still request a Tesla T4 GPU on Dataflow because of the hardcoded .with_resource_hints call. We should conditionally apply the resource hints only when running on GPU.
inference_transform = RunInference(KeyedModelHandler(model_handler))
if device == 'cuda':
inference_transform = inference_transform.with_resource_hints(
accelerator="type:nvidia-tesla-t4;count:1;install-nvidia-driver")
predictions = (
to_infer
| 'Reshuffle' >> beam.Reshuffle()
| 'RunInference' >> inference_transform)| def ensure_pubsub_resources( | ||
| project: str, topic_path: str, subscription_path: str): | ||
| publisher = pubsub_v1.PublisherClient() | ||
| subscriber = pubsub_v1.SubscriberClient() | ||
|
|
||
| topic_name = topic_path.split("/")[-1] | ||
| subscription_name = subscription_path.split("/")[-1] | ||
|
|
||
| full_topic_path = publisher.topic_path(project, topic_name) | ||
| full_subscription_path = subscriber.subscription_path( | ||
| project, subscription_name) | ||
|
|
||
| try: | ||
| publisher.get_topic(request={"topic": full_topic_path}) | ||
| except NotFound: | ||
| publisher.create_topic(name=full_topic_path) | ||
|
|
||
| try: | ||
| subscriber.get_subscription( | ||
| request={"subscription": full_subscription_path}) | ||
| except NotFound: | ||
| subscriber.create_subscription( | ||
| name=full_subscription_path, topic=full_topic_path) | ||
|
|
||
|
|
||
| def cleanup_pubsub_resources( | ||
| project: str, topic_path: str, subscription_path: str): | ||
| publisher = pubsub_v1.PublisherClient() | ||
| subscriber = pubsub_v1.SubscriberClient() | ||
|
|
||
| topic_name = topic_path.split("/")[-1] | ||
| subscription_name = subscription_path.split("/")[-1] | ||
|
|
||
| full_topic_path = publisher.topic_path(project, topic_name) | ||
| full_subscription_path = subscriber.subscription_path( | ||
| project, subscription_name) | ||
|
|
||
| try: | ||
| subscriber.delete_subscription( | ||
| request={"subscription": full_subscription_path}) | ||
| logging.info(f"Deleted subscription: {subscription_name}") | ||
| except NotFound: | ||
| logging.info(f"Subscription already deleted: {subscription_name}") | ||
|
|
||
| try: | ||
| publisher.delete_topic(request={"topic": full_topic_path}) | ||
| logging.info(f"Deleted topic: {topic_name}") | ||
| except NotFound: | ||
| logging.info(f"Topic already deleted: {topic_name}") | ||
|
|
There was a problem hiding this comment.
Splitting the topic/subscription path and reconstructing it with project causes a mismatch if the user specified a fully qualified path in a different project. Since Beam's Pub/Sub IOs require fully qualified paths anyway, we should just use topic_path and subscription_path directly.
def ensure_pubsub_resources(
project: str, topic_path: str, subscription_path: str):
publisher = pubsub_v1.PublisherClient()
subscriber = pubsub_v1.SubscriberClient()
try:
publisher.get_topic(request={"topic": topic_path})
except NotFound:
publisher.create_topic(name=topic_path)
try:
subscriber.get_subscription(
request={"subscription": subscription_path})
except NotFound:
subscriber.create_subscription(
name=subscription_path, topic=topic_path)
def cleanup_pubsub_resources(
project: str, topic_path: str, subscription_path: str):
publisher = pubsub_v1.PublisherClient()
subscriber = pubsub_v1.SubscriberClient()
try:
subscriber.delete_subscription(
request={"subscription": subscription_path})
logging.info(f"Deleted subscription: {subscription_path}")
except NotFound:
logging.info(f"Subscription already deleted: {subscription_path}")
try:
publisher.delete_topic(request={"topic": topic_path})
logging.info(f"Deleted topic: {topic_path}")
except NotFound:
logging.info(f"Topic already deleted: {topic_path}")| def ensure_pubsub_resources( | ||
| project: str, topic_path: str, subscription_path: str): | ||
| publisher = pubsub_v1.PublisherClient() | ||
| subscriber = pubsub_v1.SubscriberClient() | ||
|
|
||
| topic_name = topic_path.split("/")[-1] | ||
| subscription_name = subscription_path.split("/")[-1] | ||
|
|
||
| full_topic_path = publisher.topic_path(project, topic_name) | ||
| full_subscription_path = subscriber.subscription_path( | ||
| project, subscription_name) | ||
|
|
||
| try: | ||
| publisher.get_topic(request={"topic": full_topic_path}) | ||
| except NotFound: | ||
| publisher.create_topic(name=full_topic_path) | ||
|
|
||
| try: | ||
| subscriber.get_subscription( | ||
| request={"subscription": full_subscription_path}) | ||
| except NotFound: | ||
| subscriber.create_subscription( | ||
| name=full_subscription_path, topic=full_topic_path) | ||
|
|
||
|
|
||
| def cleanup_pubsub_resources( | ||
| project: str, topic_path: str, subscription_path: str): | ||
| publisher = pubsub_v1.PublisherClient() | ||
| subscriber = pubsub_v1.SubscriberClient() | ||
|
|
||
| topic_name = topic_path.split("/")[-1] | ||
| subscription_name = subscription_path.split("/")[-1] | ||
|
|
||
| full_topic_path = publisher.topic_path(project, topic_name) | ||
| full_subscription_path = subscriber.subscription_path( | ||
| project, subscription_name) | ||
|
|
||
| try: | ||
| subscriber.delete_subscription( | ||
| request={"subscription": full_subscription_path}) | ||
| logging.info(f"Deleted subscription: {subscription_name}") | ||
| except NotFound: | ||
| logging.info(f"Subscription already deleted: {subscription_name}") | ||
|
|
||
| try: | ||
| publisher.delete_topic(request={"topic": full_topic_path}) | ||
| logging.info(f"Deleted topic: {topic_name}") | ||
| except NotFound: | ||
| logging.info(f"Topic already deleted: {topic_name}") | ||
|
|
There was a problem hiding this comment.
Splitting the topic/subscription path and reconstructing it with project causes a mismatch if the user specified a fully qualified path in a different project. Since Beam's Pub/Sub IOs require fully qualified paths anyway, we should just use topic_path and subscription_path directly.
def ensure_pubsub_resources(
project: str, topic_path: str, subscription_path: str):
publisher = pubsub_v1.PublisherClient()
subscriber = pubsub_v1.SubscriberClient()
try:
publisher.get_topic(request={"topic": topic_path})
except NotFound:
publisher.create_topic(name=topic_path)
try:
subscriber.get_subscription(
request={"subscription": subscription_path})
except NotFound:
subscriber.create_subscription(
name=subscription_path, topic=topic_path)
def cleanup_pubsub_resources(
project: str, topic_path: str, subscription_path: str):
publisher = pubsub_v1.PublisherClient()
subscriber = pubsub_v1.SubscriberClient()
try:
subscriber.delete_subscription(
request={"subscription": subscription_path})
logging.info(f"Deleted subscription: {subscription_path}")
except NotFound:
logging.info(f"Subscription already deleted: {subscription_path}")
try:
publisher.delete_topic(request={"topic": topic_path})
logging.info(f"Deleted topic: {topic_path}")
except NotFound:
logging.info(f"Topic already deleted: {topic_path}")| def ensure_pubsub_resources( | ||
| project: str, topic_path: str, subscription_path: str): | ||
| publisher = pubsub_v1.PublisherClient() | ||
| subscriber = pubsub_v1.SubscriberClient() | ||
|
|
||
| topic_name = topic_path.split("/")[-1] | ||
| subscription_name = subscription_path.split("/")[-1] | ||
|
|
||
| full_topic_path = publisher.topic_path(project, topic_name) | ||
| full_subscription_path = subscriber.subscription_path( | ||
| project, subscription_name) | ||
|
|
||
| try: | ||
| publisher.get_topic(request={"topic": full_topic_path}) | ||
| except NotFound: | ||
| publisher.create_topic(name=full_topic_path) | ||
|
|
||
| try: | ||
| subscriber.get_subscription( | ||
| request={"subscription": full_subscription_path}) | ||
| except NotFound: | ||
| subscriber.create_subscription( | ||
| name=full_subscription_path, topic=full_topic_path) | ||
|
|
||
|
|
||
| def cleanup_pubsub_resources( | ||
| project: str, topic_path: str, subscription_path: str): | ||
| publisher = pubsub_v1.PublisherClient() | ||
| subscriber = pubsub_v1.SubscriberClient() | ||
|
|
||
| topic_name = topic_path.split("/")[-1] | ||
| subscription_name = subscription_path.split("/")[-1] | ||
|
|
||
| full_topic_path = publisher.topic_path(project, topic_name) | ||
| full_subscription_path = subscriber.subscription_path( | ||
| project, subscription_name) | ||
|
|
||
| try: | ||
| subscriber.delete_subscription( | ||
| request={"subscription": full_subscription_path}) | ||
| logging.info(f"Deleted subscription: {subscription_name}") | ||
| except NotFound: | ||
| logging.info(f"Subscription already deleted: {subscription_name}") | ||
|
|
||
| try: | ||
| publisher.delete_topic(request={"topic": full_topic_path}) | ||
| logging.info(f"Deleted topic: {topic_name}") | ||
| except NotFound: | ||
| logging.info(f"Topic already deleted: {topic_name}") | ||
|
|
There was a problem hiding this comment.
Splitting the topic/subscription path and reconstructing it with project causes a mismatch if the user specified a fully qualified path in a different project. Since Beam's Pub/Sub IOs require fully qualified paths anyway, we should just use topic_path and subscription_path directly.
def ensure_pubsub_resources(
project: str, topic_path: str, subscription_path: str):
publisher = pubsub_v1.PublisherClient()
subscriber = pubsub_v1.SubscriberClient()
try:
publisher.get_topic(request={"topic": topic_path})
except NotFound:
publisher.create_topic(name=topic_path)
try:
subscriber.get_subscription(
request={"subscription": subscription_path})
except NotFound:
subscriber.create_subscription(
name=subscription_path, topic=topic_path)
def cleanup_pubsub_resources(
project: str, topic_path: str, subscription_path: str):
publisher = pubsub_v1.PublisherClient()
subscriber = pubsub_v1.SubscriberClient()
try:
subscriber.delete_subscription(
request={"subscription": subscription_path})
logging.info(f"Deleted subscription: {subscription_path}")
except NotFound:
logging.info(f"Subscription already deleted: {subscription_path}")
try:
publisher.delete_topic(request={"topic": topic_path})
logging.info(f"Deleted topic: {topic_path}")
except NotFound:
logging.info(f"Topic already deleted: {topic_path}")| def run_load_pipeline(known_args, pipeline_args): | ||
| """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" | ||
| # enforce smaller/CPU-only defaults for feeder | ||
| override_or_add(pipeline_args, '--device', 'CPU') |
There was a problem hiding this comment.
pipeline_args is a shared list passed from the main thread. Modifying it in-place in a background thread is a race condition risk. We should copy it first using pipeline_args = list(pipeline_args).
| def run_load_pipeline(known_args, pipeline_args): | |
| """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" | |
| # enforce smaller/CPU-only defaults for feeder | |
| override_or_add(pipeline_args, '--device', 'CPU') | |
| def run_load_pipeline(known_args, pipeline_args): | |
| """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" | |
| pipeline_args = list(pipeline_args) | |
| # enforce smaller/CPU-only defaults for feeder | |
| override_or_add(pipeline_args, '--device', 'CPU') |
| def run_load_pipeline(known_args, pipeline_args): | ||
| """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" | ||
| # enforce smaller/CPU-only defaults for feeder | ||
| override_or_add(pipeline_args, '--device', 'CPU') |
There was a problem hiding this comment.
pipeline_args is a shared list passed from the main thread. Modifying it in-place in a background thread is a race condition risk. We should copy it first using pipeline_args = list(pipeline_args).
| def run_load_pipeline(known_args, pipeline_args): | |
| """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" | |
| # enforce smaller/CPU-only defaults for feeder | |
| override_or_add(pipeline_args, '--device', 'CPU') | |
| def run_load_pipeline(known_args, pipeline_args): | |
| """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" | |
| pipeline_args = list(pipeline_args) | |
| # enforce smaller/CPU-only defaults for feeder | |
| override_or_add(pipeline_args, '--device', 'CPU') |
| def run_load_pipeline(known_args, pipeline_args): | ||
| """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" | ||
| # enforce smaller/CPU-only defaults for feeder | ||
| override_or_add(pipeline_args, '--device', 'CPU') |
There was a problem hiding this comment.
pipeline_args is a shared list passed from the main thread. Modifying it in-place in a background thread is a race condition risk. We should copy it first using pipeline_args = list(pipeline_args).
| def run_load_pipeline(known_args, pipeline_args): | |
| """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" | |
| # enforce smaller/CPU-only defaults for feeder | |
| override_or_add(pipeline_args, '--device', 'CPU') | |
| def run_load_pipeline(known_args, pipeline_args): | |
| """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" | |
| pipeline_args = list(pipeline_args) | |
| # enforce smaller/CPU-only defaults for feeder | |
| override_or_add(pipeline_args, '--device', 'CPU') |
| image_inputs = processor( | ||
| images=images, | ||
| return_tensors="pt", | ||
| ) | ||
| image_inputs = { | ||
| k: (v.to(self.device) if torch.is_tensor(v) else v) | ||
| for k, v in image_inputs.items() | ||
| } |
There was a problem hiding this comment.
Hugging Face BatchEncoding / BatchFeature objects have a built-in .to(device) method that cleanly moves all internal tensors to the specified device. We can replace the dict comprehension with .to(self.device).
| image_inputs = processor( | |
| images=images, | |
| return_tensors="pt", | |
| ) | |
| image_inputs = { | |
| k: (v.to(self.device) if torch.is_tensor(v) else v) | |
| for k, v in image_inputs.items() | |
| } | |
| image_inputs = processor( | |
| images=images, | |
| return_tensors="pt", | |
| ).to(self.device) |
| text_inputs = processor( | ||
| text=texts, | ||
| return_tensors="pt", | ||
| padding=True, | ||
| truncation=True, | ||
| ) | ||
| text_inputs = { | ||
| k: (v.to(self.device) if torch.is_tensor(v) else v) | ||
| for k, v in text_inputs.items() | ||
| } |
There was a problem hiding this comment.
Hugging Face BatchEncoding / BatchFeature objects have a built-in .to(device) method that cleanly moves all internal tensors to the specified device. We can replace the dict comprehension with .to(self.device).
| text_inputs = processor( | |
| text=texts, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| ) | |
| text_inputs = { | |
| k: (v.to(self.device) if torch.is_tensor(v) else v) | |
| for k, v in text_inputs.items() | |
| } | |
| text_inputs = processor( | |
| text=texts, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| ).to(self.device) |
Please add a meaningful description for your change here
Thank you for your contribution! Follow this checklist to help us incorporate your contribution quickly and easily:
addresses #123), if applicable. This will automatically add a link to the pull request in the issue. If you would like the issue to automatically close on merging the pull request, commentfixes #<ISSUE NUMBER>instead.CHANGES.mdwith noteworthy changes.See the Contributor Guide for more tips on how to make review process smoother.
To check the build health, please visit https://github.com/apache/beam/blob/master/.test-infra/BUILD_STATUS.md
GitHub Actions Tests Status (on master branch)
See CI.md for more information about GitHub Actions CI or the workflows README to see a list of phrases to trigger workflows.